import tensorflow as tf
import numpy as np
import cv2
import time
# import matplotlib.pyplot as plt

# Create an interactive session; it installs itself as the default session.
sess = tf.InteractiveSession()

# Input placeholders: `x` holds flattened images (28 * 28 = 784 values per row),
# `y_` holds 10-column label rows (presumably one-hot digits — confirm with the
# training script, since `y_` is not used in this file).
x = tf.placeholder(tf.float32, shape=[None, 28 * 28])
y_ = tf.placeholder(tf.float32, shape=[None, 10])

# redefine each layers
def conv2d(x, W):
    """Convolve `x` with the filter bank `W` using unit stride and SAME padding.

    Unit stride with SAME padding keeps the spatial size of the output equal to
    that of the input; only the channel count changes (to W's last dimension).
    """
    return tf.nn.conv2d(
        x,
        W,
        strides=[1] * 4,
        padding='SAME',
    )

def max_pool_2x2(x):
    """Halve the height and width of `x` with a 2x2 max-pool (SAME padding)."""
    # Same window and stride: pool over H and W only, not batch or channels.
    window = [1, 2, 2, 1]
    return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')

def norm(x):
    """Apply local response normalisation to `x`.

    Not applied anywhere in the active network; kept for experimentation.
    """
    return tf.nn.lrn(
        x,
        depth_radius=4,
        bias=1.0,
        alpha=0.01 / 9.0,
        beta=0.75,
    )

# Trainable parameters of the CNN. They are initialised from tf.truncated_normal
# with no stddev argument (TF's default is used). In the __main__ path the
# values are overwritten by saver.restore, so the initialiser only matters if
# this module is used for training.
#
# NOTE(review): `name=` is passed to tf.truncated_normal, not to tf.Variable,
# so it names the initializer op. The variables themselves get TF's default
# names ("Variable", "Variable_1", ...) in creation order. The checkpoint is
# presumably keyed on those default names. Do not reorder these lines or move
# `name=` onto tf.Variable without re-saving the checkpoint; verify with
# tf.train.list_variables.

# Conv layer 1: 5x5 kernels, 1 input channel -> 32 feature maps.
W_conv_1 = tf.Variable(tf.truncated_normal(shape=[5,5,1,32], dtype=tf.float32, name='W_conv_1'))
b_conv_1 = tf.Variable(tf.truncated_normal(shape=[32], dtype=tf.float32, name='b_conv_1'))

# Conv layer 2: 5x5 kernels, 32 -> 64 feature maps.
W_conv_2 = tf.Variable(tf.truncated_normal(shape=[5,5,32,64], dtype=tf.float32, name='W_conv_2'))
b_conv_2 = tf.Variable(tf.truncated_normal(shape=[64], dtype=tf.float32, name='b_conv_2'))

# Fully connected layer 1: flattened 7x7x64 feature map -> 1024 units
# (a 28x28 input halved by two 2x2 max-pools gives 7x7).
W_fc_1 = tf.Variable(tf.truncated_normal(shape=[7*7*64,1024], dtype=tf.float32, name='W_fc_1'))
b_fc_1 = tf.Variable(tf.truncated_normal(shape=[1024], dtype=tf.float32, name='b_fc_1'))

# Fully connected layer 2 (output): 1024 units -> 10 class scores.
W_fc_2 = tf.Variable(tf.truncated_normal(shape=[1024,10], dtype=tf.float32, name='W_fc_2'))
b_fc_2 = tf.Variable(tf.truncated_normal(shape=[10], dtype=tf.float32, name='b_fc_2'))


def cnn_net(x_image):
    """Build the CNN forward pass and return per-class softmax probabilities.

    `x_image` is an NHWC image batch. The flatten step assumes 7x7x64 features
    after two 2x2 poolings, i.e. presumably a [N, 28, 28, 1] input. The result
    has one row per image and 10 columns. The `norm` (LRN) helper is not
    applied.
    """
    features = x_image
    # Two conv -> ReLU -> 2x2 max-pool stages.
    for kernel, bias in ((W_conv_1, b_conv_1), (W_conv_2, b_conv_2)):
        features = max_pool_2x2(tf.nn.relu(conv2d(features, kernel) + bias))

    flat = tf.reshape(features, [-1, 7 * 7 * 64])
    hidden = tf.nn.relu(tf.matmul(flat, W_fc_1) + b_fc_1)
    logits = tf.matmul(hidden, W_fc_2) + b_fc_2
    return tf.nn.softmax(logits)

# Saver over every variable defined above; used to load the trained weights.
saver = tf.train.Saver()

if __name__ == '__main__':
    saver.restore(sess, "model/cnn_saved/cnn.ckpt")
    print("read model successfully")

    # Test images are named data/data2/<digit>.<sample>.jpg. The digit is
    # 0..9 (and is also the expected label); the sample is 1..8.
    num_digits = 10
    samples_per_digit = 8
    total = num_digits * samples_per_digit
    count = 0
    start_time = time.time()

    # Build the inference graph once and feed each image through the `x`
    # placeholder. Calling cnn_net() per image would append a fresh copy of
    # the network's ops to the default graph on every iteration.
    x_image = tf.reshape(x, [-1, 28, 28, 1])
    y_conv = cnn_net(x_image)

    for j in range(samples_per_digit):
        for i in range(num_digits):
            img_path = 'data/data2/%s.%s.jpg' % (i, j + 1)
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread reports a missing/unreadable file by returning None.
                raise IOError('could not read image: %s' % img_path)
            # cv2.imread returns colour images in BGR channel order.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            img = cv2.GaussianBlur(img, (3, 3), 0)
            # NOTE(review): img is uint8, so `* 255` is computed in uint8 and
            # wraps modulo 256: pixel p becomes (256 - p) % 256 before the
            # float conversion. Together with `/ 255.0` this roughly inverts
            # and normalises the image (pure black stays 0). It is kept as-is
            # so accuracy stays comparable with earlier runs. If an explicit
            # inversion for dark-on-light digits is intended, use
            # (255.0 - img) / 255.0 instead.
            im_data = np.array(np.reshape(img, [28, 28]) * 255,
                               dtype=np.float32) / 255.0
            probs = sess.run(y_conv, feed_dict={x: im_data.reshape(1, 28 * 28)})
            if int(np.argmax(probs[0])) == i:
                count += 1
    res = count / float(total)
    print('use time: %.3f s' % (time.time() - start_time))
    print('accuracy is %.1f%%' % (res * 100))